import cv2
import numpy as np
from fatigue_algo.scrfd import scrfd
from fatigue_algo.landmark_mbv2.onnx.detector import Detector
from fatigue_algo.fatigue_detector.head_detector import HeadDetector_real
from service.detect_handler import DetectHandler
from fatigue_algo.yaml_load import load_yaml
from fatigue_algo.fatigue_detector.eye_Detector import EyeDetector_real
import unittest

# Detectors and configuration shared by every test in MyTest. They are created
# once, at module import time; the file paths below are relative (presumably
# resolved against the current working directory — confirm).
scrfd_detector = scrfd.SCRFD("scrfd_500m_kps.onnx")  # face detection model
landmark_detector = Detector()  # landmark model; the tests call its detect_landmarks()
headDetector = HeadDetector_real()  # head-pose / head-action detector, shared (and mutated) across tests
cfg = load_yaml("blink_validate_config.yaml")  # config handed to DetectHandler in the blink test

class MyTest(unittest.TestCase):
    """Tests for the face / landmark / head-pose pipeline and the blink detector.

    ``test_get_z_shift`` and ``test_read_EAR_B`` run unattended. The camera
    tests are interactive: they stream webcam 0 into an OpenCV window named
    "res" until ESC is pressed, and are skipped when the camera cannot be opened.
    """

    ESC_KEY = 27  # key code returned by cv2.waitKey for the Escape key

    @staticmethod
    def _detect_first_face(frame):
        """Detect faces in *frame* and compute landmarks for the first one.

        Returns ``(dets, det, landmark)``:
        - ``dets`` is the first element of ``scrfd_detector.detect_faces(frame)``
          (presumably the per-face detections — confirm);
        - ``det`` is ``dets[0]``;
        - ``landmark`` is the landmark result for ``det``.
        Returns ``None`` when no face is detected.
        """
        dets = scrfd_detector.detect_faces(frame)[0]
        if len(dets) == 0:
            return None
        det = dets[0]
        landmark = landmark_detector.detect_landmarks(frame, [[det]])
        return dets, det, landmark

    @staticmethod
    def _draw_text(frame, text, y):
        """Draw ``str(text)`` in red (BGR ``(0, 0, 255)``) at ``(60, y)`` on *frame*, in place."""
        cv2.putText(frame, str(text), (60, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    def _run_camera_loop(self, on_frame, api_preference=None):
        """Stream webcam 0, calling ``on_frame(frame)`` for every frame read.

        Every successfully read frame is first passed to *on_frame*, which may
        draw on it, and then shown in the "res" window. Showing every frame
        means the window exists, and can receive ESC, even when no face is
        detected.

        The loop ends when ESC is pressed or the capture is no longer open.
        The test is skipped if the camera cannot be opened. The camera and the
        windows are released even if *on_frame* raises.

        :param on_frame: callable receiving each BGR frame read from the camera.
        :param api_preference: optional OpenCV capture backend (e.g. ``cv2.CAP_DSHOW``).
        """
        if api_preference is None:
            camera = cv2.VideoCapture(0)
        else:
            camera = cv2.VideoCapture(0, api_preference)
        if not camera.isOpened():
            camera.release()
            self.skipTest("camera 0 could not be opened")
        try:
            while camera.isOpened():
                ret, frame = camera.read()
                if ret:  # frame is None when the read failed, so never display it then
                    on_frame(frame)
                    cv2.imshow("res", frame)
                # waitKey also processes HighGUI events, which lets the window repaint.
                if cv2.waitKey(10) == self.ESC_KEY:
                    break
        finally:
            cv2.destroyAllWindows()
            camera.release()

    def test_get_z_shift(self):
        """Run the pitch/roll/yaw/x/y/z-shift estimate on a still frontal-face image."""
        img = cv2.imread("service/front_face.png")
        # cv2.imread returns None (it does not raise) when the file is missing or unreadable.
        self.assertIsNotNone(img, "could not read service/front_face.png")
        dets = scrfd_detector.detect_faces(img)[0]
        self.assertGreater(len(dets), 0, "no face detected in service/front_face.png")
        # NOTE(review): landmarks are requested with ``[dets]`` here but with
        # ``[[dets[0]]]`` in the camera tests — confirm which form
        # detect_landmarks expects.
        landmark = landmark_detector.detect_landmarks(img, [dets])
        headDetector.get_pitch_roll_yaw_x_y_zShift(landmark_detector, img, dets, landmark)

    def test_z_shift_with_camera(self):
        """Interactive: overlay the estimated x / y / z shift on the live webcam feed."""
        def on_frame(frame):
            face = self._detect_first_face(frame)
            if face is None:
                return
            dets, _, landmark = face
            _, _, _, x, y, z = headDetector.get_pitch_roll_yaw_x_y_zShift(
                landmark_detector, frame, dets, landmark)
            for value, text_y in ((x, 130), (y, 200), (z, 270)):
                self._draw_text(frame, value, text_y)

        self._run_camera_loop(on_frame)

    def test_nod_backforth_rotate_with_camera(self):
        """Interactive: show the x / y / z shift and the nod / back-forth / rotate flags live."""
        # The thresholds are constant, so configure them once instead of on every frame.
        headDetector.setHeadPose_params(nod_threshold=2, backForth_threshold=0.2, rotate_threshold=3)

        def on_frame(frame):
            face = self._detect_first_face(frame)
            if face is None:
                return
            dets, det, landmark = face
            nod_flag, rotate_flag, backforth_flag = headDetector.head_action_detect(
                landmark_detector, frame, det, landmark)
            _, _, _, x, y, z = headDetector.get_pitch_roll_yaw_x_y_zShift(
                landmark_detector, frame, dets, landmark)
            labels = (
                (x, 60),
                (y, 120),
                (z, 180),
                ("nodding" if nod_flag == 1 else "", 240),
                ("backforth" if backforth_flag == 1 else "", 300),
                ("rotating" if rotate_flag == 1 else "", 360),
            )
            for text, text_y in labels:
                self._draw_text(frame, text, text_y)

        self._run_camera_loop(on_frame, cv2.CAP_DSHOW)

    def test_blink_func_selected(self):
        """Interactive: exercise the blink detection method ``DetectHandler.blink_detect`` on the webcam feed."""
        checkboxs = ['blink']
        eyeDetector = EyeDetector_real()
        detector = DetectHandler(cfg, checkboxs, scrfd_detector, landmark_detector, None)
        basic_info = dict()  # per-frame inputs handed to the handler via detector.basic_info
        position = (60, 60)  # forwarded to blink_detect; presumably where its output is drawn — confirm

        def on_frame(frame):
            face = self._detect_first_face(frame)
            if face is None:
                return
            _, det, landmark = face
            pitch, roll, yaw = headDetector.get_pitch_roll_yaw(landmark_detector, frame, det, landmark)  # head pose
            basic_info['frame'] = frame
            basic_info['landmark'] = landmark
            basic_info['rotate_vector'] = (pitch, roll, yaw)
            detector.basic_info = basic_info
            detector.blink_detect(checkboxs[0], eyeDetector, position)

        self._run_camera_loop(on_frame, cv2.CAP_DSHOW)

    def test_read_EAR_B(self):
        """Print the mean of the EAR values stored one per line in ``EAR_B.txt``.

        Blank lines are ignored. The test fails if the file contains no values.
        """
        with open("EAR_B.txt", encoding='utf-8') as file:
            ear_values = [np.float64(line) for line in file if line.strip()]
        # np.mean of an empty list would silently yield nan (with a RuntimeWarning).
        self.assertTrue(ear_values, "EAR_B.txt contains no EAR values")
        mean = np.mean(ear_values)
        print(f"EAR_mean = {mean}")